package mysql

import (
	"fmt"
	"io"
	"text/template"

	"gitee.com/lipore/plume/db_gorm/gen/code"
	"gitee.com/lipore/plume/db_gorm/gen/spec"
)

// createMethodData is the data passed to the create-method templates
// (createMethodTemplate / multiCreateMethodTemplate).
type createMethodData struct {
	// repositoryMethodData carries the receiver and method spec consumed by the
	// "repositoryMethod" template block (presumably the method signature).
	repositoryMethodData
	// Entity is the persistence type the domain object(s) are converted into
	// via fromDomainObject before being inserted.
	Entity       code.Type
	// DomainObject is the method parameter holding the value(s) to insert;
	// its Name is referenced directly in the generated code.
	DomainObject code.Param
}

// createMethodTemplate renders the body of a repository method that inserts a
// single domain object. The generated code converts the domain object into the
// persistence entity, inserts it inside a transaction on the connection returned
// by db_gorm.LoadTx(ctx), and returns the entity converted back into a domain
// object. The generated method is expected to have a `ctx` parameter in scope.
//
// entity is already a pointer, so it is passed to Create as-is rather than as
// a pointer-to-pointer.
var createMethodTemplate = `
{{block "repositoryMethod" .}}{{end}} {
	entity := &{{typeToCode .Entity}}{}
	entity.fromDomainObject({{.DomainObject.Name}})
	conn, err := db_gorm.LoadTx(ctx)
	if err != nil {
		return nil, err
	}
	err = conn.Transaction(func(tx *gorm.DB) error {
		return tx.Create(entity).Error
	})
	if err != nil {
		return nil, err
	}
	return entity.toDomainObject(), nil
}`

// multiCreateMethodTemplate renders the body of a repository method that
// inserts a collection of domain objects. Every element is converted into a
// persistence entity up front; the entities are then inserted one by one inside
// a single transaction on the connection returned by db_gorm.LoadTx(ctx), and
// each inserted entity is converted back into a domain object for the result.
// The generated method is expected to have a `ctx` parameter in scope.
//
// All inserts go through the transaction handle `tx`, not the outer `conn`, so
// a failure part-way through rolls back the rows already inserted. Entities are
// addressed by index so values GORM writes back land in the slice element
// itself rather than in a range-loop copy.
var multiCreateMethodTemplate = `
{{block "repositoryMethod" .}}{{end}} {
	entities := make([]{{typeToCode .Entity}}, len({{.DomainObject.Name}}))
	for i, d := range {{.DomainObject.Name}} {
		entities[i].fromDomainObject(d)
	}
	conn, err := db_gorm.LoadTx(ctx)
	if err != nil {
		return nil, err
	}
	dos := make({{typeToCode .DomainObject.Type}}, len(entities))
	err = conn.Transaction(func(tx *gorm.DB) error {
		for i := range entities {
			if err := tx.Create(&entities[i]).Error; err != nil {
				return err
			}
			dos[i] = entities[i].toDomainObject()
		}
		return nil
	})
	if err != nil {
		return nil, err
	}
	return dos, nil
}`

// renderCreateMethod writes the Go source of a repository insert method to
// writer.
//
// s describes the generated method; its second parameter (s.Params[1]) is the
// domain object to insert, or the collection of them when operation.Mode is
// spec.QueryModeMany. (The first parameter is presumably the context the
// generated code refers to as ctx.) entity is the persistence type the domain
// object is converted into, and repository is the type whose pointer becomes
// the method receiver, named "repository".
//
// It returns an error if s has fewer than two parameters, if the repository
// method template or the create template cannot be prepared, or if executing
// the template fails.
func renderCreateMethod(s spec.MethodSpec, entity code.Type, repository code.Type, operation spec.InsertOperation, writer io.Writer) error {
	// Guard the s.Params[1] access below, which would otherwise panic.
	if len(s.Params) < 2 {
		return fmt.Errorf("create method: expected at least 2 parameters, got %d", len(s.Params))
	}

	tmpl, err := renderRepositoryMethod(template.New("create Method"))
	if err != nil {
		return err
	}

	body := createMethodTemplate
	if operation.Mode == spec.QueryModeMany {
		body = multiCreateMethodTemplate
	}
	if tmpl, err = tmpl.Parse(body); err != nil {
		return fmt.Errorf("parsing create method template: %w", err)
	}

	tmplData := createMethodData{
		repositoryMethodData: repositoryMethodData{
			Receiver: code.Param{Name: "repository", Type: code.PointerType{ContainedType: repository}},
			Spec:     s,
		},
		Entity:       entity,
		DomainObject: s.Params[1],
	}

	return tmpl.Execute(writer, tmplData)
}
